import os
import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon

# 1. Input/output paths (relative to the current working directory).
MEDIA_FOLDER = "Media_Monthly_Counts_Quad&Event (Added)"  # one .xlsx workbook per media outlet
BOTH_DATA_FILE = "Monthly_Counts_Quad&Event.xlsx"  # pooled counts over all samples
OUTPUT_FILE = "JSD_Quad&Event.xlsx"  # destination of the per-media JSD table

# 2. Every month from 2014-01 through 2023-12 inclusive; series are padded to this range.
FULL_MONTHS = pd.period_range(start="2014-01", end="2023-12", freq="M")

# 3. Generic loader: read raw counts, parse the month index, merge duplicate
#    months and pad to the full month range -- no normalisation here.
def load_counts(path, sheet, infer_format=True):
    """Load one sheet of raw monthly counts as a month-indexed DataFrame.

    Parameters
    ----------
    path : str
        Excel workbook to read (opened with the ``openpyxl`` engine).
    sheet : str
        Sheet name. It must contain a ``Month`` column; every other column
        is treated as a count column.
    infer_format : bool, default True
        If True, let pandas infer the date format of ``Month``. If False,
        ``Month`` is parsed as ``YYYY-MM``; entries that do not match are
        retried as ``YYYYMM`` and finally with format inference.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``FULL_MONTHS``. Rows whose month cannot be parsed are
        dropped, duplicate months are summed, non-numeric cells count as 0
        and months absent from the sheet are filled with 0.

    Raises
    ------
    KeyError
        If the sheet has no ``Month`` column.
    """
    df = pd.read_excel(path, sheet_name=sheet, engine="openpyxl")

    # Parse the Month column into monthly periods.
    raw = df["Month"].astype(str).str.strip()
    if infer_format:
        # `infer_datetime_format=True` is deprecated since pandas 2.0, where
        # inferring the format is already the default behaviour.
        dt = pd.to_datetime(raw, errors="coerce")
    else:
        dt = pd.to_datetime(raw, format="%Y-%m", errors="coerce")
        # Entries not written as YYYY-MM (e.g. the integer 201401 -> "201401",
        # or a date cell -> "2014-01-01 00:00:00") would otherwise become NaT
        # and be dropped silently; retry them as YYYYMM, then with inference.
        for fallback in ("%Y%m", None):
            missing = dt.isna()
            if not missing.any():
                break
            dt[missing] = pd.to_datetime(raw[missing], format=fallback, errors="coerce")
    df["Period"] = dt.dt.to_period("M")
    df = df.drop(columns=["Month"]).set_index("Period")
    # Drop rows whose month could not be parsed at all (NaT index) explicitly
    # rather than relying on groupby's dropna default.
    df = df[df.index.notna()]

    # Coerce counts to numbers (invalid cells -> 0) and sum duplicate months.
    df = df.apply(pd.to_numeric, errors="coerce").fillna(0)
    df = df.groupby(df.index).sum()
    # Pad to the full month range; months missing from the sheet get 0.
    return df.reindex(FULL_MONTHS, fill_value=0)

# 4. Load the pooled raw counts of all samples and remember their column labels.
glob_q_counts = load_counts(BOTH_DATA_FILE, "QuadClass", infer_format=False)  # Month parsed with "%Y-%m"
glob_e_counts = load_counts(BOTH_DATA_FILE, "EventRootCode", infer_format=False)
quad_cols = glob_q_counts.columns
evt_cols = glob_e_counts.columns

# 5. Mean Jensen-Shannon distance between every media outlet and the pooled data.
results = []
epsilon = 1e-8  # additive smoothing so that no probability is exactly zero


def _to_distribution(counts):
    """Row-normalise ``counts`` into per-month probability distributions.

    ``epsilon`` is added to every cell first, so categories with zero counts
    keep a tiny non-zero probability and every row has a positive total.
    """
    smoothed = counts + epsilon
    return smoothed.div(smoothed.sum(axis=1), axis=0)


def _mean_jsd(media_counts, global_counts, global_dist):
    """Return the mean base-2 Jensen-Shannon distance of media vs. pooled data.

    ``media_counts`` and ``global_counts`` are raw month x category counts
    with the same columns; ``global_dist`` is ``_to_distribution(global_counts)``.

    Only months in which BOTH sides have at least one count are compared.
    load_counts pads every frame to FULL_MONTHS with zeros, and smoothing
    would turn such an all-zero month into a uniform distribution whose
    distance to the pooled data says nothing about the outlet.
    Returns NaN when no month qualifies.
    """
    media_dist = _to_distribution(media_counts)

    common = media_counts.index.intersection(global_counts.index)
    both_have_data = (media_counts.loc[common].sum(axis=1) > 0) & (
        global_counts.loc[common].sum(axis=1) > 0
    )
    months = common[both_have_data.to_numpy()]
    if len(months) == 0:
        return np.nan

    distances = [
        jensenshannon(media_dist.loc[m].to_numpy(), global_dist.loc[m].to_numpy(), base=2)
        for m in months
    ]
    return float(np.nanmean(distances))


# The pooled distributions do not depend on the outlet: compute them once.
glob_q_dist = _to_distribution(glob_q_counts)
glob_e_dist = _to_distribution(glob_e_counts)

# Sorted so the output row order is reproducible across runs/file systems.
for fname in sorted(os.listdir(MEDIA_FOLDER)):
    # Only .xlsx workbooks; names starting with "~$" are skipped (typically
    # Office lock files of currently open workbooks, not data).
    if not fname.lower().endswith(".xlsx") or fname.startswith("~$"):
        continue
    media = os.path.splitext(fname)[0]
    path = os.path.join(MEDIA_FOLDER, fname)

    # Load the outlet's counts and align its columns with the pooled ones.
    mq = load_counts(path, "QuadClass").reindex(columns=quad_cols, fill_value=0)
    me = load_counts(path, "EventRootCode").reindex(columns=evt_cols, fill_value=0)

    mean_q = _mean_jsd(mq, glob_q_counts, glob_q_dist)
    mean_e = _mean_jsd(me, glob_e_counts, glob_e_dist)

    # Average of whichever distances exist (NaN if neither does), without the
    # RuntimeWarning np.nanmean emits on an all-NaN input.
    available = [v for v in (mean_q, mean_e) if not np.isnan(v)]
    results.append({
        "Media": media,
        "JS_Quad": mean_q,
        "JS_Event": mean_e,
        "JS_Mean": float(np.mean(available)) if available else np.nan,
    })

# 6. Write the results, one row per media outlet. Explicit columns keep
#    set_index("Media") working even when no workbook was found.
(
    pd.DataFrame(results, columns=["Media", "JS_Quad", "JS_Event", "JS_Mean"])
    .set_index("Media")
    .to_excel(OUTPUT_FILE)
)

print(f"✅ 完成！Jensen–Shannon 距离结果保存在 {OUTPUT_FILE}")
